#' Helper function to create a customized metric
#'
#' @export
mx.metric.custom <- function(name, feval) {
  init <- function() {
    c(0, 0)
  }
  update <- function(label, pred, state) {
    m <- feval(as.array(label), as.array(pred))
    state <- c(state[[1]] + 1, state[[2]] + m)
    return(state)
  }
  get <- function(state) {
    list(name=name, value=(state[[2]]/state[[1]]))
  }
  ret <- (list(init=init, update=update, get=get))
  class(ret) <- "mx.metric"
  return(ret)
}

#' Accuracy metric for classification
#'
#' @export
mx.metric.accuracy <- mx.metric.custom("accuracy", function(label, pred) {
  ypred = max.col(t(as.array(pred)), tie="first")
  return(sum((as.array(label) + 1) == ypred) / length(label))
})

#' Helper function for top-k accuracy
is.num.in.vect <- function(vect, num){
  resp <- any(is.element(vect, num))
  return(resp)
}

#' Top-k accuracy metric for classification
#'
#' @export
mx.metric.top_k_accuracy <- mx.metric.custom("top_k_accuracy", function(label, pred, top_k = 5) {
  if(top_k == 1){
    return(mx.metric.accuracy(label,pred))
  } else{
    ypred <- apply(pred,2,function(x) order(x, decreasing=TRUE)[1:top_k])
    ans <- apply(ypred, 2, is.num.in.vect, num = as.array(label + 1))
    acc <- sum(ans)/length(label)  
    return(acc)
  }
})

#' MSE (Mean Squared Error) metric for regression
#'
#' @export
mx.metric.mse <- mx.metric.custom("mse", function(label, pred) {
  res <- mean((label-pred)^2)
  return(res)
})
    
#' RMSE (Root Mean Squared Error) metric for regression
#'
#' @export
mx.metric.rmse <- mx.metric.custom("rmse", function(label, pred) {
  res <- sqrt(mean((label-pred)^2))
  return(res)
})

#' MAE (Mean Absolute Error) metric for regression
#'
#' @export
mx.metric.mae <- mx.metric.custom("mae", function(label, pred) {
  res <- mean(abs(label-pred))
  return(res)
})

#' RMSLE (Root Mean Squared Logarithmic Error) metric for regression
#'
#' @export
mx.metric.rmsle <- mx.metric.custom("rmsle", function(label, pred) {
  res <- sqrt(mean((log(pred + 1) - log(label + 1))^2))
  return(res)
})

